import numpy as np
import cupy as cp
import scipy
from scipy.optimize import linear_sum_assignment

# Import TensorLy
import tensorly as tl
from tensorly.decomposition import symmetric_parafac_power_iteration as sym_parafac

import time
import csv
import random
import sys

# Plotting and serialization
import matplotlib.pyplot as plt
import pickle
# Import TLDA and baselines
from version0_99.tlda_wrapper import TLDA
from version0_99.test_tlda import get_mu
import version0_15.tensor_lda_clean as tlda_mid

# Select the NumPy backend for tensorly at import time; the functions below
# re-select it before doing their own tensor work.
backend = "numpy"
tl.set_backend(backend)

# Default vocabulary size used as the `vocab` default of the gen_fit_* functions.
VOCAB = 500 # 1000

def validate_gammad(gammad_arr, theta_arr, transpose=False, num_tops=3, smoothing=1e-6):
    '''Match learned factors to reference factors and score the match by RMSE.

    The learned array is clipped at zero, smoothed, and divided by its
    column sums; both arrays then have every row scaled to unit L2 norm and
    are transposed. A correlation matrix between the two is fed to
    ``linear_sum_assignment`` (negated, so that the assignment maximises
    total correlation).

    Parameters
    ----------
    gammad_arr : array
        Learned factors. Assumed to be 2-D -- TODO confirm the axis layout
        against the caller.
    theta_arr : array
        Reference factors, same layout as ``gammad_arr`` (assumed).
    transpose : bool
        When false, the correlation is ``factor.T @ sample`` and the RMSE is
        computed on the normalised arrays. When true, the correlation is
        ``factor @ sample.T`` and the RMSE is computed on the *raw* inputs
        (permuted columns of ``theta_arr`` against ``gammad_arr.T``).
    num_tops : int
        Number of matched columns compared in the RMSE.
    smoothing : float
        Additive smoothing mass applied to the learned factors.

    Returns
    -------
    tuple
        ``(permutation, rmse)`` where ``permutation`` is the
        ``(row_ind, col_ind)`` pair returned by ``linear_sum_assignment``.
    '''
    tl.set_backend('numpy')

    def _unit_rows(arr):
        # Scale each row to unit L2 norm and return the transpose.
        # `transpose(0, 1)` keeps the axis order of a 2-D array unchanged.
        scaled = (arr.transpose(0, 1) / tl.norm(arr, axis=1)[:, None]).T
        # All-zero rows divide 0 by 0; treat those entries as 0.
        scaled[np.isnan(scaled)] = 0
        return scaled

    factor = tl.tensor(cp.asnumpy(gammad_arr))
    factor[factor < 0.] = 0.
    # Smooth, then normalise each column to sum to one.
    factor *= (1. - smoothing)
    factor += (smoothing / factor.shape[1])
    factor /= factor.sum(axis=0)
    factor = _unit_rows(factor)

    sample = _unit_rows(tl.tensor(cp.asnumpy(theta_arr)))

    if transpose:
        M_corr = tl.dot(factor, sample.T)
    else:
        M_corr = tl.dot(factor.T, sample)
    # Negate so that the minimum-cost assignment maximises correlation.
    permutation = linear_sum_assignment(-M_corr)

    if transpose:
        matched = np.array([theta_arr[:, permutation[1][i]] for i in range(num_tops)])
        return permutation, tl.metrics.regression.RMSE(tl.tensor(matched), tl.tensor(gammad_arr.T))

    matched = np.array([sample[:, permutation[1][i]] for i in range(num_tops)])
    return permutation, tl.metrics.regression.RMSE(tl.tensor(matched), factor.T)


def loss_rec(factor, cumulant, theta):
    """Return the combined loss and its two components for a symmetric factor.

    A third-order tensor is rebuilt from three copies of ``factor`` with
    ``tl.cp_to_tensor`` (weights left as ``None``). ``cumulant`` is the
    tensor being approximated (the original comment calls it M3).

    Returns
    -------
    tuple
        ``(total, ortho_loss, rec_loss)`` where ``rec_loss`` is the squared
        2-norm of ``rec - cumulant``, ``ortho_loss`` is
        ``(1 + theta) / 2`` times the squared 2-norm of ``rec``, and
        ``total`` is their sum.
    """
    reconstruction = tl.cp_to_tensor((None, [factor, factor, factor]))
    residual_term = tl.norm(reconstruction - cumulant, 2) ** 2
    penalty_term = (1 + theta) / 2 * tl.norm(reconstruction, 2) ** 2
    total = penalty_term + residual_term
    return total, penalty_term, residual_term

def correlate(a, v):
    """Return ``np.correlate`` of the standardised sequences ``a`` and ``v``.

    ``a`` is centred and divided by ``std(a) * len(a)``; ``v`` is centred and
    divided by ``std(v)``. Both are converted with ``cp.asarray`` before the
    correlation call, so a working CuPy installation is required.
    """
    a_scale = np.std(a, dtype=np.float64) * len(a)
    a_std = cp.asarray((a - tl.mean(a)) / a_scale)

    v_scale = np.std(v, dtype=np.float64)
    v_std = cp.asarray((v - tl.mean(v)) / v_scale)

    return np.correlate(a_std, v_std)

def create_data(vocab=500, seed=None):
    """Generate one synthetic dataset with ``get_mu`` and pickle it under ``data/``.

    Three files are written, each suffixed with ``{seed}_{vocab}``:
    ``synthetic_data_*`` (the data ``x``), ``true_alpha_*`` (``alpha_0``)
    and ``true_mu_*`` (``mu``). The ``data/`` directory must already exist.

    Parameters
    ----------
    vocab : int
        Vocabulary size forwarded to ``get_mu``.
    seed : int or None
        Seed forwarded to ``get_mu`` and embedded in the file names.
    """
    num_tops = 2
    num_tweets = 20000
    density = 15

    print("Vocab: " + str(vocab))
    print("num_tweets: " + str(num_tweets))
    print("density: " + str(density))

    # Retry generation until get_mu succeeds.
    # NOTE(review): if get_mu is deterministic for a fixed seed, a ValueError
    # would repeat forever here -- confirm against get_mu.
    x = None
    while x is None:
        try:
            x, mu, _, alpha_0 = get_mu(num_tops, vocab, num_tweets, density, seed)
        except ValueError:
            pass

    # Context managers ensure each file is flushed and closed before the
    # gen_fit_* functions read it back.
    artifacts = (
        (x, f'data/synthetic_data_{seed}_{vocab}.obj'),
        (alpha_0, f'data/true_alpha_{seed}_{vocab}.obj'),
        (mu, f'data/true_mu_{seed}_{vocab}.obj'),
    )
    for payload, path in artifacts:
        with open(path, 'wb') as handle:
            pickle.dump(payload, handle)

def postprocess(factors_unwhitened, mu, x, vocab, num_tops, smoothing, decenter=False, postprocess=True, name="", alpha_0 = 1):
    '''Optionally clean up learned factors, then score them against ground truth.

    Parameters
    ----------
    factors_unwhitened : array
        Learned factors; indexed below as ``factors_unwhitened.T[i, :]`` per
        topic. When it is already a NumPy array and ``postprocess`` is true,
        it is modified in place.
    mu : array
        Ground truth; sliced as ``mu[:, 0, :]`` before use.
    x : array
        Data whose mean over axis 0 is added back when ``decenter`` is true.
    vocab : int
        Vocabulary size, used to reshape that mean to ``(vocab, 1)``.
    num_tops : int
        Number of topics to score.
    smoothing : float
        Additive smoothing mass.
    decenter : bool
        Add the data mean to the factors (only when ``postprocess`` is true).
    postprocess : bool
        Run the clip / smooth / normalise step. (This parameter shadows the
        function name inside the body.)
    name : str
        Label used in printed messages and in the output file name
        ``results/accuracies{name}.txt``.
    alpha_0 : float
        Unused in this function.

    Returns
    -------
    tuple
        ``(res, accuracy)``: ``res`` is a list of ``(label, seconds)`` timing
        entries (empty when ``postprocess`` is false) and ``accuracy`` is a
        list with one ``correlate`` result per topic.
    '''
    res = []
    if postprocess:
        if decenter:
            # Add the per-word mean of the data back onto every factor column.
            wc = np.asarray(tl.mean(x, axis=0))
            wc = tl.reshape(wc, (vocab, 1))

            factors_unwhitened = np.asarray(factors_unwhitened)
            factors_unwhitened += wc

        factors_unwhitened = np.asarray(factors_unwhitened)
        t1 = time.time()
        factors_unwhitened[factors_unwhitened < 0.] = 0.
        # Smooth, then normalise each column to sum to one.
        factors_unwhitened *= (1. - smoothing)
        factors_unwhitened += (smoothing / factors_unwhitened.shape[1])
        factors_unwhitened /= factors_unwhitened.sum(axis=0)
        t2 = time.time()
        print("Smoothing and Normalization: " + str(t2-t1))
        res.append((name + ' smoothing and normalization', t2-t1))

    # Score against the ground truth.
    mu = np.asarray(mu)[:, 0, :]
    permutation, RMSE = validate_gammad(factors_unwhitened.T, mu, num_tops = num_tops)
    print("Fit RMSE: " + str(RMSE.item()))
    print(name + " Test Against Ground Truth")

    # Reorder the ground-truth rows with the assignment found above.
    matched_mu = mu[permutation[1]]

    with open("results/accuracies"+name+".txt", 'w') as out_file:
        print(mu.shape, file=out_file)
        print(matched_mu.shape, file=out_file)

    # The original decenter / non-decenter branches were identical.
    accuracy = [
        correlate(factors_unwhitened.T[i, :], matched_mu[i, :])
        for i in range(num_tops)
    ]

    return res, accuracy

# Get correlation for 2 previous tensor LDA approaches
def gen_fit_0_15(n_iter_max=200, vocab=VOCAB, theta=1, learning_rate = 0.01, seed=None):
    '''Fit the two baseline tensor-LDA variants and report timings and accuracy.

    Loads the dataset written by ``create_data`` for ``(seed, vocab)``, runs
    the shared moment / whitening steps from ``tlda_mid``, then:

    * PARAFAC branch: whiten M1, build M3, decompose it with
      ``sym_parafac``, and unwhiten the factors;
    * SGD branch: ``tlda_mid.simulate_all`` on the whitened data, then
      unwhiten the factors.

    Both sets of factors are scored with ``postprocess``.

    Parameters
    ----------
    n_iter_max : int
        Currently unused; ``simulate_all`` is called with ``max_iter=100``.
    vocab : int
        Vocabulary size (selects the data files and is passed to
        ``postprocess``).
    theta : float
        Forwarded to ``simulate_all``.
    learning_rate : float
        Forwarded to ``simulate_all`` as ``lr1``.
    seed : int or None
        Selects the data files and is forwarded to ``simulate_all``.

    Returns
    -------
    tuple
        ``([("total PARAFAC time", s)], [("total SGD time", s)],
        mean_parafac_correlation, mean_sgd_correlation)``. Each total
        includes the four shared steps (M1, M2, W, whiten X).
    '''
    num_tops = 2
    smoothing = 1e-5

    res = []
    # NOTE: pickle must only be used on trusted files; these are the ones
    # produced by create_data in this same script.
    with open(f'data/synthetic_data_{seed}_{vocab}.obj', 'rb') as handle:
        x = pickle.load(handle)
    with open(f'data/true_alpha_{seed}_{vocab}.obj', 'rb') as handle:
        alpha_0 = pickle.load(handle)
    with open(f'data/true_mu_{seed}_{vocab}.obj', 'rb') as handle:
        mu = pickle.load(handle)

    tl.set_backend("numpy")

    x = tl.tensor(x)

    t1 = time.time()
    M1 = tlda_mid.get_M1(x)
    t2 = time.time()
    print("M1: " + str(t2-t1))
    res.append(('M1 calc', t2-t1))

    t1 = time.time()
    M2_img = tlda_mid.get_M2(x, M1, alpha_0)
    t2 = time.time()
    print("M2: " + str(t2-t1))
    res.append(('M2 calc', t2-t1))

    t1 = time.time()
    W, W_inv = tlda_mid.whiten(M2_img, num_tops) # W (n_words x n_topics)
    t2 = time.time()
    # Diagnostic print of W^T M2 W, outside the timed region.
    print(tl.dot(tl.dot(W.T, M2_img), W))
    print("W: " + str(t2-t1))
    res.append(('W calc', t2-t1))

    W = tl.tensor(W)
    W_inv = tl.tensor(W_inv)

    t1 = time.time()
    X_whitened = tl.dot(x, W)
    t2 = time.time()
    print("Whiten X: " + str(t2-t1))
    res.append(('whiten X', t2-t1))

    # Everything timed so far is shared; the SGD branch starts from a copy.
    res_copy = res.copy()

    # This is where the two versions branch off -- begin with the PARAFAC one.
    t1 = time.time()
    M1_whitened = tl.dot(M1, W)
    t2 = time.time()
    print("Whiten M1: " + str(t2-t1))
    res.append(('whiten M1', t2-t1))

    t1 = time.time()
    M3_final = tlda_mid.get_M3(X_whitened, M1_whitened, alpha_0)
    t2 = time.time()
    print("Parafac M3: " + str(t2-t1))
    res.append(('construct M3', t2-t1))

    t1 = time.time()
    _, phis_learned_parafac = sym_parafac(M3_final, rank=num_tops, n_repeat=100, n_iteration=50, verbose=False)
    t2 = time.time()
    print("Parafac Decomposition: " + str(t2-t1))
    res.append(('decompose parafac', t2-t1))

    t1 = time.time()
    factors_unwhitened_parafac = tl.dot(W_inv, phis_learned_parafac)
    t2 = time.time()
    print("Unwhitening parafac factors: " + str(t2-t1))
    res.append(('unwhiten factors parafac', t2-t1))

    # SGD branch. max_iter is fixed at 100; n_iter_max is not forwarded.
    t1 = time.time()
    _, phis_learned = tlda_mid.simulate_all(X_whitened, alpha_0, num_tops, lr1 = learning_rate, theta=theta, seed=seed, verbose = False, min_iter = 10, max_iter=100)
    t2 = time.time()
    print("SGD Calc: " + str(t2-t1))
    res_copy.append(('SGD calc', t2-t1))

    t1 = time.time()
    factors_unwhitened = tl.dot(W_inv, phis_learned)
    t2 = time.time()
    print("Unwhitening factors: " + str(t2-t1))
    res_copy.append(('unwhiten factors SGD', t2-t1))

    res3, accuracy_parafac = postprocess(factors_unwhitened_parafac, mu, x, vocab, num_tops, smoothing, decenter=False, name="parafac")
    res2, accuracy_uncentered = postprocess(factors_unwhitened, mu, x, vocab, num_tops, smoothing, decenter=False)
    res.extend(res3)
    res_copy.extend(res2)

    # Collapse each step list into a single total.
    total_parafac = sum((t for _, t in res), 0.0)
    total_sgd = sum((t for _, t in res_copy), 0.0)
    res = [("total PARAFAC time", total_parafac)]
    res_copy = [("total SGD time", total_sgd)]

    mean_parafac = np.mean(np.array([a.item() for a in accuracy_parafac]))
    mean_sgd = np.mean(np.array([a.item() for a in accuracy_uncentered]))
    return res, res_copy, mean_parafac, mean_sgd


# Get correlation for our TLDA approach
def gen_fit_0_20(n_iter_train = 2001, batch_size_grad= 100, vocab = VOCAB, theta=1, learning_rate = 0.01, seed=None):
    '''Fit the batched ``TLDA`` model and report its timing and accuracy.

    Loads the dataset written by ``create_data`` for ``(seed, vocab)``,
    times ``tlda.fit`` plus ``tlda.transform(x, predict=False)`` together,
    and scores the result with ``postprocess`` (with its clean-up step
    disabled).

    Parameters
    ----------
    n_iter_train : int
        Training iteration count passed to ``TLDA``.
    batch_size_grad : int
        Passed to ``TLDA`` as ``third_order_cumulant_batch``.
    vocab : int
        Vocabulary size (selects the data files and is passed to
        ``postprocess``).
    theta : float
        Passed to ``TLDA`` as ``theta``.
    learning_rate : float
        Passed to ``TLDA``.
    seed : int or None
        Selects the data files and is passed to ``TLDA`` as ``random_seed``.

    Returns
    -------
    tuple
        ``([("total batched TLDA time", seconds)], mean_correlation)``.
    '''
    num_tops = 2
    batch_size_pca = 20000
    n_iter_test = 10
    smoothing = 1e-5

    # NOTE: pickle must only be used on trusted files; these are the ones
    # produced by create_data in this same script.
    with open(f'data/synthetic_data_{seed}_{vocab}.obj', 'rb') as handle:
        x = pickle.load(handle)
    with open(f'data/true_alpha_{seed}_{vocab}.obj', 'rb') as handle:
        alpha_0 = pickle.load(handle)
    with open(f'data/true_mu_{seed}_{vocab}.obj', 'rb') as handle:
        mu = pickle.load(handle)

    tl.set_backend("numpy")

    x = tl.tensor(x)
    tlda = TLDA(
        num_tops, alpha_0,
        n_iter_train, n_iter_test,
        learning_rate,
        pca_batch_size=batch_size_pca,
        third_order_cumulant_batch=batch_size_grad,
        gamma_shape=1.0, smoothing=smoothing,
        theta=theta, random_seed=seed)

    t1 = time.time()
    tlda.fit(x)
    factors_unwhitened = tlda.transform(x, predict=False)
    t2 = time.time()
    elapsed = t2 - t1
    print("TLDA time: " + str(elapsed))

    # postprocess=False skips the clean-up step, so the returned timing list
    # is empty and is ignored here.
    _, accuracy = postprocess(factors_unwhitened, mu, x, vocab, num_tops, smoothing, decenter=True, postprocess=False, alpha_0 = alpha_0)

    res = [('total batched TLDA time', elapsed)]
    return res, np.mean(np.array([a.item() for a in accuracy]))


def main():
    '''Run the benchmark sweep and write timing and correlation CSVs.

    For each of ``nums`` seeds and each vocabulary size, a dataset is
    generated with ``create_data`` and every ``(lr, theta)`` configuration
    is fitted with ``gen_fit_0_15`` (PARAFAC and SGD baselines) and
    ``gen_fit_0_20`` (batched TLDA). Mean run times per vocabulary size go
    to ``results/results_20k_reproduced.csv``; mean correlations per
    ``(vocab, lr, theta)`` go to
    ``results/results_correlation_20k_reproduced.csv``. The ``results/``
    directory must already exist.
    '''
    print("new version")
    nums = 10
    tot_parafac = {}
    tot_uncentered = {}
    tot_centered = {}
    acc_parafac = []
    acc_uncentered = []
    acc_centered = []
    vocab_arr = [500, 1000, 1500]
    theta_arr = [10]
    lr_arr = [1e-4]

    seed_arr =  [4068186562, 2293672821, 899886193, 511320915, 133031152, 3156287835, 133148045, 823517892, 1864981031, 2544549694, 2488735442, 2760984162, 82373710, 82366915, 2228613541, 3486018369, 970599314, 598791238, 1254893416, 622627511, 138100630, 4027698360, 3942538484, 189181519, 1743605279, 1053334122, 756433468, 3412400620, 4121758197, 3220859515]

    # Number of fits that contribute to each per-vocab timing total.
    runs_per_vocab = nums * len(lr_arr) * len(theta_arr)
    # (vocab, lr, theta) -> (parafac, SGD, batched) correlation lists, so the
    # averages stay correct when lr_arr / theta_arr hold more than one value.
    corr_by_config = {}

    for i in range(nums):
        seed = seed_arr[i]
        for vocab in vocab_arr:
            create_data(vocab=vocab, seed=seed)
            print('created data!')
            tl.set_backend("numpy")
            for lr in lr_arr:
                for theta in theta_arr:
                    res_parafac, res_uncentered, accuracy_parafac, accuracy_uncentered = gen_fit_0_15(vocab=vocab, seed=seed, theta=theta, learning_rate=lr)
                    # The batched model uses its own fixed learning rate; the
                    # "learning rate" CSV column describes the baselines only.
                    res_centered, accuracy_centered = gen_fit_0_20(n_iter_train=2001, vocab=vocab, theta=theta, learning_rate=1e-2, seed=seed)

                    acc_parafac.append(accuracy_parafac)
                    acc_centered.append(accuracy_centered)
                    acc_uncentered.append(accuracy_uncentered)

                    bucket = corr_by_config.setdefault((vocab, lr, theta), ([], [], []))
                    bucket[0].append(accuracy_parafac)
                    bucket[1].append(accuracy_uncentered)
                    bucket[2].append(accuracy_centered)

                    # Accumulate the (label, seconds) total of every run.
                    for totals, (label, elapsed) in (
                        (tot_parafac, res_parafac[0]),
                        (tot_uncentered, res_uncentered[0]),
                        (tot_centered, res_centered[0]),
                    ):
                        first_label, running = totals.get(vocab, (label, 0.0))
                        totals[vocab] = (first_label, running + elapsed)

        print(acc_parafac)
        print(acc_centered)
        print(acc_uncentered)

    # Turn the accumulated totals into per-run means.
    for vocab in vocab_arr:
        for totals in (tot_parafac, tot_uncentered, tot_centered):
            label, total = totals[vocab]
            totals[vocab] = (label, total / runs_per_vocab)

    # newline='' is what the csv module expects of files it writes to.
    with open('results/results_20k_reproduced.csv', 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(["Vocab", "Step", "Time (s)"])
        for vocab in vocab_arr:
            for label, seconds in (tot_parafac[vocab], tot_uncentered[vocab], tot_centered[vocab]):
                csvwriter.writerow([vocab, label, seconds])

    with open('results/results_correlation_20k_reproduced.csv', 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(["Iteration", "Vocabulary Size", "learning rate", "theta", "Correlation Parafac", "Correlation SGD TLDA", "Correlation Batched TLDA"])
        # Each row is an average over all iterations; the "Iteration" column
        # keeps the value written historically (the last iteration index).
        last_iteration = nums - 1
        for vocab in vocab_arr:
            for lr in lr_arr:
                for theta in theta_arr:
                    parafac_vals, sgd_vals, batched_vals = corr_by_config[(vocab, lr, theta)]
                    csvwriter.writerow([
                        str(last_iteration), str(vocab), str(lr), str(theta),
                        str(sum(parafac_vals) / nums),
                        str(sum(sgd_vals) / nums),
                        str(sum(batched_vals) / nums),
                    ])

    print("Done!")

if __name__ == '__main__':
    main()
